# 对掩码区域内部图像进行填充，填充为白色圆形 避免白色的
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import math
from tqdm import tqdm



def mask_white(img, mask):
    """
    在掩码区域内部填充最大轮廓对应的圆形。本函数读入一幅图像和一个掩码，然后在图像上的最大掩码对应区域的中心填充圆形。

    :param img: 输入的彩色图像，数据类型为NumPy数组。
    :param mask: 输入的掩码图像，是一个二值图像，数据类型为NumPy数组。
    :return: 经过处理后返回带有填充圆形的图像。
    """
    # 确保掩码是二值图像，且为单通道8位图像
    if len(mask.shape) != 2 or mask.dtype != np.uint8:
        raise ValueError("Mask must be a single-channel 8-bit image.")

    # 在掩码图像中寻找轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 寻找最大轮廓
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)

        # 计算最大轮廓的几何中心
        M = cv2.moments(largest_contour)
        if M['m00'] != 0:
            cx = int(M['m10']/M['m00'])
            cy = int(M['m01']/M['m00'])
        else:
            # 如果轮廓面积为零（避免除以零），取轮廓的第一个点
            cx, cy = largest_contour[0][0]

        # 计算最大轮廓的外接矩形，用于确定圆的半径
        _, _, w, h = cv2.boundingRect(largest_contour)
        radius = min(w, h) // 3

        # 在原始图像上填充圆形，圆心为最大轮廓的几何中心，颜色为白色(255, 255, 255)，填充实心圆
        cv2.circle(img, (cx, cy), radius, (255, 255, 255), -1)
        patch_circle(img,(cy,cx),radius*3//1.8)

    # 返回处理后的图像
    return img

def create_circle_mask(center,radius,xy):
    #判断是否在圆内
    return 1 if (xy[0]-center[0])**2+(xy[1]-center[1])**2 < radius**2 else 0
    

def patch_circle(img,center,radius):
    #对图像进行圆形填充
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            if not create_circle_mask(center,radius,(i,j)):
                img[i,j] = 255
    return img



def draw_circle(img,x,y,r):
    cv2.circle(img, (x,y), r, (0, 0, 255), 1)
    return img


def pair_gen(rgb_path, mask_path):
    """
    生成图像对
    :param rgb_path: RGB图像路径
    :param mask_path: 掩码图像路径
    :return: RGB图像，掩码图像
    """
    rgb = cv2.imread(rgb_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    return rgb, mask


if __name__ == '__main__':
    rgb_dir =r'E:\pose\datasets\obj_ac\obj_000000\rgb'
    mask_dir =r'E:\pose\datasets\obj_ac\obj_000000\mask'
    rgb_save_dir = r'E:\pose\datasets\obj_ac\obj_000000\rgb_gaussian'

    if not os.path.exists(rgb_save_dir):
        os.makedirs(rgb_save_dir)
    
    length = len(os.listdir(rgb_dir))
    for i in tqdm(range(length)):
        rgb_path = os.path.join(rgb_dir, f'{str(i)}.jpg')
        mask_path = os.path.join(mask_dir, f'{str(i)}.jpg')
        rgb_save_path = os.path.join(rgb_save_dir, f'{str(i)}.jpg')
        rgb, mask = pair_gen(rgb_path, mask_path)
        rgb = mask_white(rgb, mask)
        cv2.imwrite(rgb_save_path, rgb)